"""
Module containing the encoders.
"""
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F


class EncoderUnit(nn.Module):
    """A single encoder stage: ``nn.Conv2d`` followed by a ReLU.

    All positional and keyword arguments are forwarded verbatim to
    ``nn.Conv2d`` (channels, kernel size, stride, padding, ...).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.conv = nn.Conv2d(*args, **kwargs)
        self.ac = nn.ReLU()

    def forward(self, x):
        """Apply convolution then the activation."""
        return self.ac(self.conv(x))


class DecoderUnit(nn.Module):
    """A single decoder stage: ``nn.ConvTranspose2d`` followed by a ReLU.

    All positional and keyword arguments are forwarded verbatim to
    ``nn.ConvTranspose2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.conv = nn.ConvTranspose2d(*args, **kwargs)
        self.ac = nn.ReLU()

    def forward(self, x):
        """Apply transposed convolution then the activation."""
        return self.ac(self.conv(x))


class Reshape(nn.Module):
    """Module wrapper around ``torch.reshape`` with a fixed target shape.

    Useful inside ``nn.Sequential`` where plain tensor ops cannot appear.
    """

    def __init__(self, shape):
        super().__init__()
        # Target shape; may contain -1 to infer one dimension.
        self.shape = shape

    def forward(self, x):
        return torch.reshape(x, self.shape)


def reparameterize(mu, log_var):
    """Draw a sample from N(mu, sigma^2) via the reparameterization trick.

    Args:
        mu: mean tensor of the latent Gaussian.
        log_var: log-variance tensor, same shape as ``mu``.

    Returns:
        Tensor ``mu + eps * sigma`` where ``eps ~ N(0, I)``, keeping the
        sampling step differentiable w.r.t. ``mu`` and ``log_var``.
    """
    # sigma = exp(log_var / 2); 0.5 * log_var is bit-identical to log_var / 2.
    std = torch.exp(0.5 * log_var)
    noise = torch.randn_like(mu)
    return mu + noise * std


class VAE(nn.Module):
    """Convolutional variational autoencoder for 64x64 images.

    The encoder applies four stride-2 convolutions (64 -> 32 -> 16 -> 8 -> 4
    spatial), flattens the resulting 4x4 feature map, and maps it through two
    linear layers to ``2 * latent_dim`` values interpreted as (mu, log_var).
    The decoder mirrors this with transposed convolutions; its final layer has
    no activation, so the reconstruction is returned as raw logits/values.

    Args:
        img_size: (channels, height, width) of the input images; the
            architecture requires height == width == 64.
        latent_dim: dimensionality of the latent space.

    Raises:
        ValueError: if ``img_size`` is not 64x64 spatially, since the
            hard-coded flatten size (``hid_channels * 16``) only matches a
            4x4 feature map.
    """

    def __init__(self, img_size, latent_dim=10):
        super().__init__()
        # Layer parameters
        hid_channels = 32
        hidden_dim = 128
        self.latent_dim = latent_dim
        self.img_size = img_size

        # The flatten below assumes four stride-2 convs reduce the input to
        # 4x4, which only holds for 64x64 inputs — fail fast otherwise.
        if tuple(img_size[1:]) != (64, 64):
            raise ValueError(
                "VAE requires 64x64 spatial input, got %s" % (tuple(img_size[1:]),)
            )

        n_chan = self.img_size[0]
        cnn_kwargs = dict(stride=2, padding=1, kernel_size=4)

        # Encoder: 4 conv stages, then flatten the 4x4 map and project to
        # 2 * latent_dim (concatenated mu and log_var).
        enc_layers = [
            EncoderUnit(in_channels=c, out_channels=hid_channels, **cnn_kwargs)
            for c in (n_chan, hid_channels, hid_channels, hid_channels)
        ]
        enc_layers += [
            Reshape((-1, hid_channels * 16)),
            nn.Linear(hid_channels * 16, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),
        ]
        self.encoders = nn.Sequential(*enc_layers)

        # Decoder: project latent back to a 4x4 map, then upsample with three
        # transposed-conv stages plus a final (activation-free) layer.
        dec_layers = [
            nn.Linear(latent_dim, hid_channels * 16),
            nn.LeakyReLU(),
            Reshape((-1, hid_channels, 4, 4)),
        ]
        dec_layers += [
            DecoderUnit(in_channels=hid_channels, out_channels=hid_channels,
                        **cnn_kwargs)
            for _ in range(3)
        ]
        dec_layers.append(
            nn.ConvTranspose2d(in_channels=hid_channels, out_channels=n_chan,
                               **cnn_kwargs)
        )
        self.decoder = nn.Sequential(*dec_layers)

    def encoder(self, x):
        """Encode a batch of images into the posterior parameters.

        Args:
            x: tensor of shape (batch, channels, 64, 64).

        Returns:
            Tuple ``(mu, log_var)``, each of shape (batch, latent_dim).
            The final linear output is interleaved: even positions are mu,
            odd positions are log_var (via the view/unbind split).
        """
        mu_logvar = self.encoders(x)
        mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1)
        return mu, logvar

    def forward(self, x):
        """Run the full VAE pass.

        Args:
            x: tensor of shape (batch, channels, 64, 64).

        Returns:
            ``(recon, (mu, log_var), z)`` where ``recon`` matches the input
            shape and ``z`` is the sampled latent of shape
            (batch, latent_dim). Note ``z`` is always sampled stochastically,
            including in eval mode.
        """
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)

        recon = self.decoder(z)
        return recon, (mu, logvar), z


def get_preds(model, dl):
    """Collect latent means and labels for every batch of a dataloader.

    Runs the model's encoder over ``dl`` in eval mode with gradients
    disabled, then restores the model's previous train/eval mode (the
    original code unconditionally switched the model back to train mode,
    clobbering a caller's eval state).

    Args:
        model: module exposing ``encoder(x) -> (mu, log_var)``. If it has an
            ``img_size`` attribute (channels, H, W) — as ``VAE`` does — that
            shape is used to reshape incoming batches; otherwise the previous
            hard-coded (1, 64, 64) is assumed for backward compatibility.
        dl: iterable yielding ``(img, label)`` batches.

    Returns:
        Tuple ``(preds, targets)``: the concatenated posterior means and the
        concatenated labels across all batches.
    """
    was_training = model.training
    model.eval()
    # Generalize the old hard-coded view(-1, 1, 64, 64) via the model itself.
    img_size = tuple(getattr(model, "img_size", (1, 64, 64)))
    device = next(model.parameters()).device
    preds = []
    targets = []
    with torch.no_grad():
        for img, label in dl:
            img = img.view(-1, *img_size).to(device)
            mu, _ = model.encoder(img)
            preds.append(mu)
            targets.append(label)
    if was_training:
        model.train()
    return torch.cat(preds), torch.cat(targets)
